'''
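Description: Run the trained YOLOv10 model over a folder of images and save annotated images plus per-image JSON detection results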
Author: renlirong
Date: 2024-07-25 11:32:36
LastEditTime: 2024-09-13 11:25:45
LastEditors: renlirong
'''
import os
import cv2
import torch
import matplotlib.pyplot as plt
import json
from pathlib import Path
from ultralytics import YOLOv10

# Trained weights to load, the folder of images to run inference on, and the
# folder that receives annotated images plus one JSON detection file per image.
model_path = 'runs/detect/train/weights/best.pt'
test_image_folder = 'datasets/page_seg/images/train_antd'
output_folder = 'runs/test/test_output_v10_train'

os.makedirs(output_folder, exist_ok=True)

model = YOLOv10(model_path)
# Minimum confidence a detection must have to be kept (passed as `conf`
# when the model is called).
confidence_threshold = 0.2

# Match extensions case-insensitively and including the leading dot, so that
# e.g. 'SCAN.JPG' is picked up while a name that merely ends in 'png'
# (e.g. 'notpng') is not. os.listdir returns entries in arbitrary order, so
# sort them to process images in a stable, reproducible order.
image_extensions = ('.png', '.jpg', '.jpeg')
image_paths = [
    os.path.join(test_image_folder, name)
    for name in sorted(os.listdir(test_image_folder))
    if name.lower().endswith(image_extensions)
]

for image_path in image_paths:
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread returns None (rather than raising) when a file cannot be
        # read or decoded; skip it instead of crashing further down.
        print('Skipping unreadable image:', image_path)
        continue

    # Ultralytics interprets numpy array inputs as BGR, which is exactly the
    # layout cv2.imread produces, so the image is passed unchanged. Converting
    # it to RGB first would swap the red and blue channels the model sees.
    results = model(img, conf=confidence_threshold)

    # One row per detection: x1, y1, x2, y2, confidence, class index, with
    # box coordinates in pixels. tolist() turns the tensor into nested lists
    # of plain Python numbers.
    detections = results[0].boxes.data.tolist()

    height, width = img.shape[:2]
    detection_results = []

    for component_id, detection in enumerate(detections, start=1):
        x1, y1, x2, y2, conf, cls = detection[:6]
        class_name = model.names[int(cls)]

        # Normalize the unrounded coordinates so the JSON keeps sub-pixel
        # precision; integer pixel values are only needed for drawing.
        position = {
            "x": x1 / width,
            "y": y1 / height,
            "width": (x2 - x1) / width,
            "height": (y2 - y1) / height,
        }
        detection_results.append({
            "id": f"component_{component_id}",
            "type": class_name,
            "position": position,
        })

        left, top, right, bottom = int(x1), int(y1), int(x2), int(y2)
        cv2.rectangle(img, (left, top), (right, bottom), (0, 255, 0), 2)

        # Place the label just above the box, but clamp it so it stays inside
        # the image when the box touches the top edge.
        label = f'{class_name} {conf:.2f}'
        cv2.putText(img, label, (left, max(top - 10, 10)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

    image_name = os.path.basename(image_path)
    output_path = os.path.join(output_folder, image_name)
    cv2.imwrite(output_path, img)

    json_file_name = os.path.splitext(image_name)[0] + '.json'
    json_output_path = os.path.join(output_folder, json_file_name)
    with open(json_output_path, 'w', encoding='utf-8') as f:
        json.dump(detection_results, f, indent=4)

    # Optional: display the annotated image.
    # plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    # plt.show()

print('检测完成，结果已保存至', output_folder)
